{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0IwlAjlibuUE"
      },
      "source": [
        "# Understanding and Manipulating Data in Waymax\n",
        "\n",
        "This tutorial covers data structures in Waymax."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1NTUsEy_btuU"
      },
      "outputs": [],
      "source": [
        "%%capture\n",
        "import dataclasses\n",
        "import jax\n",
        "from jax import numpy as jnp\n",
        "from matplotlib import pyplot as plt\n",
        "\n",
        "from waymax import config as _config\n",
        "from waymax import dataloader\n",
        "from waymax import datatypes"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2pcSuBb9cFZN"
      },
      "outputs": [],
      "source": [
        "# Load example data.\n",
        "config = dataclasses.replace(_config.WOD_1_1_0_VALIDATION, max_num_objects=32)\n",
        "data_iter = dataloader.simulator_state_generator(config=config)\n",
        "scenario = next(data_iter)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GZAFy24xcLQc"
      },
      "source": [
        "## JAX-based data structures\n",
        "\n",
        "The key property with all data structures in Waymax is that all data is *immutable*. This is a design decision that is inherited from JAX and enables code written using Waymax to be compatible with the powerful functional transforms in JAX, such as `jit`, `vmap`, etc. While efficiency may be a concern with immutable data structures, wrapping your function in `jax.jit` will allow JAX to optimize and replace your operations with in-place operations wherever possible, avoiding the need for excessive data copying.\n",
        "\n",
        "Additionally, all datastructures in Waymax are implemented as dataclasses. This allows convenient named access to fields, and allows simple nesting of data structures that is easy to manipulate with tree-based operations (such as those in `jax.tree_util`).\n",
        "\n",
        "The first example we will cover is the `datatypes.Trajectory` data structure, which holds the pose information for all objects. The scenario that we loaded contains a trajectory containing the logged behavior for all agents under the `scenario.log_trajectory` attribute."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sPMaI3ShdPwB"
      },
      "outputs": [],
      "source": [
        "log_trajectory = scenario.log_trajectory\n",
        "\n",
        "# Number of objects stored in this trajectory.\n",
        "print('Number of objects:', log_trajectory.num_objects)\n",
        "print('Number of timesteps:', log_trajectory.num_timesteps)\n",
        "print('Trajectory shape (num_objects, num_timesteps):', log_trajectory.shape)\n",
        "print('XYZ positions (num_objects, num_timesteps, 3):', log_trajectory.xyz.shape)\n",
        "print('XY velocities (num_objects, num_timesteps, 2):', log_trajectory.vel_xy.shape)\n",
        "print('Yaw (num_objects, num_timesteps):', log_trajectory.yaw.shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nsKEO7GCelBK"
      },
      "source": [
        "The `datatypes` module contains some helper methods that automatically map over datastructures. We can use `datatypes.dynamic_slice` to select out the trajectory belonging to a particular object or at a particular timestep. These operations, as with all JAX operations, will return new copies of the object they are modifying. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "k6cWGuVRdlQk"
      },
      "outputs": [],
      "source": [
        "# Slice by time. Select the trajectory at timestep 23.\n",
        "traj_t23 = datatypes.dynamic_slice(log_trajectory, start_index=23, slice_size=1, axis=-1)\n",
        "print('XYZ positions (num_objects, 1, 3):', traj_t23.xyz.shape)\n",
        "\n",
        "# Slice by object. Select the trajectory for object 15.\n",
        "traj_obj15 = datatypes.dynamic_slice(log_trajectory, start_index=15, slice_size=1, axis=-2)\n",
        "print('XYZ positions (1, num_timesteps, 3):', traj_obj15.xyz.shape)"
      ]
    },
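    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The idea of a helper that maps over every field can be sketched with core JAX alone. The function `slice_tree` below is a hypothetical illustration built on `jax.lax.dynamic_slice_in_dim` and `tree_map`; it is not Waymax's implementation of `datatypes.dynamic_slice`, and the dictionary is a made-up stand-in for a trajectory.\n",
        "\n",
        "```python\n",
        "import jax\n",
        "from jax import numpy as jnp\n",
        "\n",
        "def slice_tree(tree, start_index, slice_size, axis):\n",
        "  # Apply the same slice to every array leaf and return a new tree.\n",
        "  return jax.tree_util.tree_map(\n",
        "      lambda x: jax.lax.dynamic_slice_in_dim(x, start_index, slice_size, axis),\n",
        "      tree,\n",
        "  )\n",
        "\n",
        "# Toy fields with shape (num_objects=4, num_timesteps=10).\n",
        "toy = {\n",
        "    'x': jnp.arange(40.0).reshape(4, 10),\n",
        "    'valid': jnp.ones((4, 10), dtype=bool),\n",
        "}\n",
        "\n",
        "# Slice along the time axis (axis 1 of these 2-D toy arrays).\n",
        "at_t3 = slice_tree(toy, start_index=3, slice_size=1, axis=1)\n",
        "print('Sliced shape:', at_t3['x'].shape)\n",
        "print('Original shape:', toy['x'].shape)\n",
        "```\n",
        "\n",
        "The sliced axis is kept with size 1 rather than dropped, which matches the shapes printed in the cell above."
      ]
    },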
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LXrPjhzPf7EJ"
      },
      "source": [
        "Of course, JAX functions from the core library also work on Waymax data structures. The `tree_map` function is particularly useful for working with dataclasses, and will apply a single function to all fields in the data structure (recursively if there are nested data structures)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "shVRPmToe9oY"
      },
      "outputs": [],
      "source": [
        "def max_along_time(x: jax.Array) -\u003e jax.Array:\n",
        "  return jnp.max(x, axis=-1, keepdims=True)\n",
        "\n",
        "max_trajectory = jax.tree_util.tree_map(max_along_time, log_trajectory)\n",
        "print('XYZ positions (num_objects, 1, 3):', max_trajectory.xyz.shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mlPkVLA7Mcx-"
      },
      "source": [
        "To modify the values of the data structure, we can use `dataclasses.replace` to replace entire fields, and `Array.at[idx].set(value)` to selectively modify individual values. For example, to set the all yaws for object 1 to zero, we can use the following code snippet:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Z5napP9gMutj"
      },
      "outputs": [],
      "source": [
        "zeroed_traj = dataclasses.replace(\n",
        "    log_trajectory, \n",
        "    yaw=log_trajectory.yaw.at[1].set(0.0)\n",
        ")\n",
        "\n",
        "# Should be the original values.\n",
        "print('Yaws for object 0, timesteps 0 to 5:', zeroed_traj.yaw[0, 0:5])\n",
        "\n",
        "# Should be now set to 0.\n",
        "print('Yaws for object 1, timesteps 0 to 5:', zeroed_traj.yaw[1, 0:5])\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "E7DZS_BKhZLj"
      },
      "source": [
        "# Other important data structures\n",
        "\n",
        "We will now cover the remaining important data structures that are stored in a scenario.\n",
        "\n",
        "The `datatypes.RoadgraphPoints` data structure holds all static information regarding the road and environment. This includes all lanes markers, road edges, stop signs, speed bumps, and crosswalks. \n",
        "- The `x`, `y`, and `z` attributes define the spatial coordinates of the points.\n",
        "- The `type` attribute is an integer that defines what type of point (lane, edge, stop sign, etc.) the point is. See `roadgraph_samples/type` of the [Waymo Open Motion Dataset](https://waymo.com/open/data/motion/tfexample) for definitions of which value corresponds to what type of point.\n",
        "- The `dir_x` and `dir_y` attributes define the orientation of the points. Lane points will orient in the forward direction of the lane. Edge points are oriented such that the inside of the road is always on the port side (left if facing forward) of the point.\n",
        "- The `id` field is a unique identifier for each contiguous lane. Lanes end if there is an intersection or reach the edge of the map."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "k9XQmInAh4vo"
      },
      "outputs": [],
      "source": [
        "# Plot the roadgraph, with colors corresponding to the road type.\n",
        "rg_points = scenario.roadgraph_points\n",
        "\n",
        "where_valid = rg_points.valid\n",
        "plt.scatter(\n",
        "    x = rg_points.x[where_valid],\n",
        "    y = rg_points.y[where_valid],\n",
        "    s=0.1,\n",
        "    c = rg_points.types[where_valid]\n",
        ")\n",
        "plt.show()"
      ]
    },
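    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The validity, type, and ID fields can be combined as boolean masks to pick out a subset of points. The sketch below uses small made-up arrays in place of a real `RoadgraphPoints`, and the type code `2` and lane ID `7` are arbitrary placeholders rather than values taken from the dataset definition.\n",
        "\n",
        "```python\n",
        "from jax import numpy as jnp\n",
        "\n",
        "# Made-up stand-ins for the x, types, ids, and valid fields.\n",
        "x = jnp.array([0.0, 1.0, 2.0, 3.0, 4.0])\n",
        "types = jnp.array([2, 2, 15, 2, -1])\n",
        "ids = jnp.array([7, 7, 9, 8, -1])\n",
        "valid = jnp.array([True, True, True, True, False])\n",
        "\n",
        "# Select valid points of one type that belong to one lane ID.\n",
        "mask = valid & (types == 2) & (ids == 7)\n",
        "print('Selected x coordinates:', x[mask])\n",
        "```\n",
        "\n",
        "Boolean-mask indexing like this produces an output whose size depends on the data, so it is meant for eager use (as in plotting) rather than inside `jax.jit`."
      ]
    },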
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2MdtujKbjahp"
      },
      "source": [
        "The `datatypes.TrafficLights` structure holds time-varying information regarding the color and position of the traffic lights.\n",
        "- The `x`, `y`, and `z` attributes define the spatial location of the light.\n",
        "- The `state` attribute defines what color the light is at a particular instance in time.\n",
        "- The `lane_ids` attribute tells what lanes the traffic light is controlling. These can be cross-referenced with the `RoadgraphPoints.ids` field."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zse6OxgaiEbH"
      },
      "outputs": [],
      "source": [
        "traffic_lights = scenario.log_traffic_light\n",
        "\n",
        "print('Traffic Light States (num_lights, num_timesteps):', traffic_lights.shape)"
      ]
    },
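    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The cross-referencing mentioned above can be sketched with `jnp.isin`. The arrays below are made-up stand-ins for traffic light lane IDs at a single timestep and for `RoadgraphPoints.ids`; with real data, the `valid` masks of both structures should be applied first.\n",
        "\n",
        "```python\n",
        "from jax import numpy as jnp\n",
        "\n",
        "# Made-up lane IDs controlled by two traffic lights.\n",
        "light_lane_ids = jnp.array([7, 9])\n",
        "# Made-up lane ID for each roadgraph point.\n",
        "point_ids = jnp.array([7, 7, 8, 9, 10])\n",
        "\n",
        "# True for every roadgraph point on a lane controlled by some light.\n",
        "controlled = jnp.isin(point_ids, light_lane_ids)\n",
        "print('Controlled points:', controlled)\n",
        "```"
      ]
    },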
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QY-rhD1Qkr-C"
      },
      "source": [
        "Finally, `datatypes.ObjectMetadata` holds \n",
        "- The `object_types` attribute defines whether the object is a vehicle, pedestrian, or cyclist.\n",
        "- The `ids` attribute assigns a unique ID to each object.\n",
        "- The `is_sdc` attribute defines whether the object is the ego-vehicle (or self-driving car).\n",
        "- The `is_modeled` attribute marks whether the object's behavior is meant to be predicted as part of the Waymo Open Motion dataset.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Dohh8d9Rkh70"
      },
      "outputs": [],
      "source": [
        "metadata = scenario.object_metadata\n",
        "print('All object IDS:', metadata.ids)\n",
        "\n",
        "# Color-code object trajectory by whether it is the SDC or not.\n",
        "# The SDC trajectory in the center is shown in blue, and all other trajectories\n",
        "# are shown in red.\n",
        "flat_trajectory = jax.tree_util.tree_map(lambda x: jnp.reshape(x, [-1]), log_trajectory)\n",
        "colors = jnp.zeros(log_trajectory.shape, dtype=jnp.int32).at[metadata.is_sdc].set(1)\n",
        "colors = jnp.reshape(colors, [-1])\n",
        "\n",
        "where_valid = flat_trajectory.valid\n",
        "plt.scatter(\n",
        "  x=flat_trajectory.x[where_valid],\n",
        "  y=flat_trajectory.y[where_valid],\n",
        "  s=0.5,\n",
        "  c=colors[where_valid],\n",
        "  cmap='RdYlBu'\n",
        ")\n",
        "plt.show()"
      ]
    },
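    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Metadata flags can also be used to index into other structures. As a sketch with made-up arrays (not real scenario data), the SDC's row of a per-object array can be looked up from a boolean `is_sdc` flag, assuming exactly one object is flagged:\n",
        "\n",
        "```python\n",
        "from jax import numpy as jnp\n",
        "\n",
        "# Made-up flags with shape (num_objects,) and positions with shape\n",
        "# (num_objects, num_timesteps).\n",
        "is_sdc = jnp.array([False, False, True, False])\n",
        "x = jnp.arange(12.0).reshape(4, 3)\n",
        "\n",
        "# argmax over the 0/1 flags returns the index of the first flagged object.\n",
        "sdc_index = jnp.argmax(is_sdc.astype(jnp.int32))\n",
        "sdc_x = x[sdc_index]\n",
        "print('SDC index:', int(sdc_index))\n",
        "print('SDC x positions:', sdc_x)\n",
        "```\n",
        "\n",
        "Unlike boolean-mask indexing, an integer index has a fixed output shape, so this pattern also works inside `jax.jit`."
      ]
    },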
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hmNQVj-clfIr"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "colab": {
      "last_runtime": {
        "build_target": "",
        "kind": "local"
      },
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1Ee2coesuhg82e7E9HmZxafZ7raD_ug5q",
          "timestamp": 1683520258687
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
